package dbext

import (
	"fmt"
	"go/format"
	"os"
	"path/filepath"
	"strings"

	"gitee.com/tjloved/tool/helper"
)

// ModelGenerator generates Go model source files from database table metadata.
type ModelGenerator struct {
	// TableInfoGetter supplies the tables (GetTables) and their columns
	// (GetColumns) that GenModel turns into model structs.
	TableInfoGetter *TableInfoGetter
	// Directory is the output directory the generated .go files are written to.
	Directory string
	// importPackages collects the import specs needed by the model currently
	// being generated; GenModel clears it between tables.
	importPackages []string
}

// NewModelGenerator builds a ModelGenerator whose TableInfoGetter is created
// from ConnectionName (via NewTableInfoGetter) and whose generated files are
// written into directory.
func NewModelGenerator(ConnectionName string, directory string) *ModelGenerator {
	gen := new(ModelGenerator)
	gen.TableInfoGetter = NewTableInfoGetter(ConnectionName)
	gen.Directory = directory
	return gen
}

func (mg ModelGenerator) getTemplate() string {
	modelTmp := `
package model

{Import}

type {StructName} struct{
{StructField}
}

func ({ShortStructName} {StructName}) TableName() string {
	return "{TableName}"
}
`
	return modelTmp
}

// GenModel generates one Go model file per table into mg.Directory.
//
// tableNames is passed unchanged to TableInfoGetter.GetTables, which decides
// which tables are returned. For every table, the struct fields come from
// column.ToStructField, and "time" is imported when any column reports
// IsTime. The result is run through go/format before being written by output.
// When force is false, existing files are left untouched.
//
// If the generated source for a table cannot be formatted, the error is
// printed and that table is skipped; the remaining tables are still generated.
func (mg *ModelGenerator) GenModel(force bool, tableNames ...string) {
	tables := mg.TableInfoGetter.GetTables(tableNames)
	for _, table := range tables {
		// Start every table with an empty import list so packages required by a
		// previous table (including one that failed) never leak into this one.
		mg.importPackages = nil

		structName := helper.StringToCamelCase(table.Name)

		columns := mg.TableInfoGetter.GetColumns(table.Name)
		fields := make([]string, 0, len(columns))
		for _, column := range columns {
			fields = append(fields, column.ToStructField())
			if column.IsTime() {
				mg.setImportPackages("\"time\"")
			}
		}

		// Substitute every placeholder in one pass. With sequential ReplaceAll
		// calls, database-derived text (table names, field definitions) could
		// contain a placeholder token that a later substitution would rewrite.
		replacer := strings.NewReplacer(
			"{Import}", mg.getImportPackages(),
			"{StructName}", structName,
			"{ShortStructName}", helper.StringToInitialsAbbreviation(structName),
			"{StructField}", strings.Join(fields, "\n"),
			"{TableName}", table.Name,
		)
		source := replacer.Replace(mg.getTemplate())

		// Format the generated code with go/format (gofmt style).
		formattedCode, err := format.Source([]byte(source))
		if err != nil {
			fmt.Println("格式化错误:", err)
			continue
		}
		mg.output(formattedCode, table.Name, force)
	}
}

// getImportPackages renders the import declaration for the model being generated.
//
// It returns a parenthesized "import ( ... )" block that lists each distinct
// entry of mg.importPackages once, in first-seen order. It returns "" when
// there is nothing to import. Entries are emitted verbatim, so they must
// already be quoted import paths (GenModel passes "\"time\"").
func (mg *ModelGenerator) getImportPackages() string {
	// Deduplicate with a plain set rather than a reflection-based helper:
	// this cannot fail or panic on a type assertion, so imports are never
	// silently dropped.
	seen := make(map[string]struct{}, len(mg.importPackages))
	lines := make([]string, 0, len(mg.importPackages))
	for _, pkg := range mg.importPackages {
		if _, dup := seen[pkg]; dup {
			continue
		}
		seen[pkg] = struct{}{}
		lines = append(lines, " "+pkg)
	}
	if len(lines) == 0 {
		return ""
	}
	return "\nimport (\n" + strings.Join(lines, "\n") + "\n)\n"
}

// setImportPackages appends import specs to the list used for the model
// currently being generated. GenModel passes already-quoted paths such as
// "\"time\"".
func (mg *ModelGenerator) setImportPackages(pkNames ...string) {
	for _, name := range pkNames {
		mg.importPackages = append(mg.importPackages, name)
	}
}

// getFullFileName returns the path of the generated file for filename: the
// name with a ".go" extension, placed inside mg.Directory.
//
// filepath.Join uses the OS path separator and cleans the result. An empty
// Directory therefore yields a path relative to the working directory rather
// than one rooted at "/".
func (mg *ModelGenerator) getFullFileName(filename string) string {
	return filepath.Join(mg.Directory, filename+".go")
}

// output writes content to the file returned by getFullFileName(filename),
// creating mg.Directory first if necessary.
//
// When force is false and os.Stat finds the target file, nothing is written
// and a notice is printed instead. Failures to create the directory, or to
// open, write or close the file, panic.
func (mg *ModelGenerator) output(content []byte, filename string, force bool) {
	if err := os.MkdirAll(mg.Directory, os.ModePerm); err != nil {
		panic(err)
	}

	fullName := mg.getFullFileName(filename)
	if !force {
		if _, err := os.Stat(fullName); err == nil {
			fmt.Println("文件已存在:", fullName)
			return
		}
	}

	// Generated source files do not need the execute bit, so use 0644
	// instead of os.ModePerm (0777).
	file, err := os.OpenFile(fullName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
	if err != nil {
		panic(err)
	}
	if _, err := file.Write(content); err != nil {
		// Close before panicking so the descriptor is not leaked.
		_ = file.Close()
		panic(err)
	}
	// Check Close: delayed write errors can be reported here.
	if err := file.Close(); err != nil {
		panic(err)
	}
}
